from __future__ import annotations

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable
import dataclasses
import itertools
import textwrap
from typing import TypedDict, Union

import google.protobuf.json_format
import google.api_core.exceptions

from google.ai import generativelanguage as glm
from google.generativeai import string_utils

# Names exported by a star-import of this module.
__all__ = [
    "AsyncGenerateContentResponse",
    "BlockedPromptException",
    "StopCandidateException",
    "IncompleteIterationError",
    "BrokenResponseError",
    "GenerationConfigDict",
    "GenerationConfigType",
    "GenerationConfig",
    "GenerateContentResponse",
]


class BlockedPromptException(Exception):
    """Error recorded for a response whose prompt feedback carries a block reason.

    The response classes in this module store an instance of it (built from the
    raw result) and raise it when the response is iterated.
    """


class StopCandidateException(Exception):
    """Exception type exported by this module.

    NOTE(review): nothing in this module raises it; presumably callers raise it
    when a candidate stops for an unexpected reason -- confirm against its users.
    """


class IncompleteIterationError(Exception):
    """Raised when accumulated attributes are read before a streamed response is done."""


class BrokenResponseError(Exception):
    """Exception type exported by this module.

    NOTE(review): nothing in this module raises it; presumably used by callers
    to signal a response left in an unusable state -- confirm against its users.
    """


class GenerationConfigDict(TypedDict, total=False):
    """Dictionary form of the generation parameters.

    The keys mirror the fields of the `GenerationConfig` dataclass. The type is
    declared with `total=False` so that every key is optional: callers usually
    set only a few parameters, and `to_generation_config_dict` accepts any
    mapping without requiring a complete one.
    """

    # TODO(markdaoust): Python 3.11+ use `NotRequired`, ref: https://peps.python.org/pep-0655/
    candidate_count: int
    stop_sequences: Iterable[str]
    max_output_tokens: int
    temperature: float
    top_p: float
    top_k: int


@dataclasses.dataclass
class GenerationConfig:
    """A simple dataclass used to configure the generation parameters of `GenerativeModel.generate_content`.

    Every field defaults to `None`; `to_generation_config_dict` drops fields
    that are `None` when converting an instance to a dictionary.

    Attributes:
        candidate_count:
            Number of generated responses to return.
        stop_sequences:
            The set of character sequences (up
            to 5) that will stop output generation. If
            specified, the API will stop at the first
            appearance of a stop sequence. The stop sequence
            will not be included as part of the response.
        max_output_tokens:
            The maximum number of tokens to include in a
            candidate.

            If unset, this will default to output_token_limit specified
            in the model's specification.
        temperature:
            Controls the randomness of the output.

            Note: The default value varies by model, see the
            `Model.temperature` attribute of the `Model` returned by the
            `genai.get_model` function.

            Values can range from [0.0,1.0], inclusive. A value closer
            to 1.0 will produce responses that are more varied and
            creative, while a value closer to 0.0 will typically result
            in more straightforward responses from the model.
        top_p:
            Optional. The maximum cumulative probability of tokens to
            consider when sampling.

            The model uses combined Top-k and nucleus sampling.

            Tokens are sorted based on their assigned probabilities so
            that only the most likely tokens are considered. Top-k
            sampling directly limits the maximum number of tokens to
            consider, while Nucleus sampling limits the number of tokens
            based on the cumulative probability.

            Note: The default value varies by model, see the
            `Model.top_p` attribute of the `Model` returned by the
            `genai.get_model` function.
        top_k:
            Optional. The maximum number of tokens to consider when
            sampling.

            The model uses combined Top-k and nucleus sampling.

            Top-k sampling considers the set of `top_k` most probable
            tokens. Defaults to 40.

            Note: The default value varies by model, see the
            `Model.top_k` attribute of the `Model` returned by the
            `genai.get_model` function.
    """

    candidate_count: int | None = None
    stop_sequences: Iterable[str] | None = None
    max_output_tokens: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None


# Any of the forms a generation config may be supplied in: the `glm` proto
# message, a `GenerationConfigDict`, or the `GenerationConfig` dataclass.
GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig]


def to_generation_config_dict(generation_config: GenerationConfigType):
    """Normalize a generation config, in any supported form, to a plain `dict`.

    Args:
        generation_config: `None`, a `glm.GenerationConfig` message, a
            `GenerationConfig` dataclass instance, or any object exposing
            `keys` (treated as a mapping).

    Returns:
        An empty dict for `None`; the message's `to_dict` output for a proto
        message; the dataclass fields whose value is not `None` for a
        `GenerationConfig`; otherwise `dict(generation_config)`.

    Raises:
        TypeError: If the value is none of the supported forms.
    """
    if generation_config is None:
        return {}

    if isinstance(generation_config, glm.GenerationConfig):
        message_cls = type(generation_config)
        return message_cls.to_dict(generation_config)  # pytype: disable=attribute-error

    if isinstance(generation_config, GenerationConfig):
        # Unset dataclass fields are `None`; leave them out of the result.
        field_values = dataclasses.asdict(generation_config)
        return {name: setting for name, setting in field_values.items() if setting is not None}

    if hasattr(generation_config, "keys"):
        return dict(generation_config)

    raise TypeError(
        "Did not understand `generation_config`, expected a `dict` or"
        f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:"
        f" {generation_config}"
    )


def _join_citation_metadatas(
    citation_metadatas: Iterable[glm.CitationMetadata],
):
    citation_metadatas = list(citation_metadatas)
    return citation_metadatas[-1]


def _join_safety_ratings_lists(
    safety_ratings_lists: Iterable[list[glm.SafetyRating]],
):
    ratings = {}
    blocked = collections.defaultdict(list)

    for safety_ratings_list in safety_ratings_lists:
        for rating in safety_ratings_list:
            ratings[rating.category] = rating.probability
            blocked[rating.category].append(rating.blocked)

    blocked = {category: any(blocked) for category, blocked in blocked.items()}

    safety_list = []
    for (category, probability), blocked in zip(ratings.items(), blocked.values()):
        safety_list.append(
            glm.SafetyRating(category=category, probability=probability, blocked=blocked)
        )

    return safety_list


def _join_contents(contents: Iterable[glm.Content]):
    """Concatenate the parts of several `Content` messages into a single one.

    The role is the first non-empty role found (or `""` if there is none).
    Parts are kept in order; whenever two neighbouring parts both have
    non-empty `text`, they are fused into one part holding the joined text.

    Args:
        contents: The `Content` pieces to join, in stream order.

    Returns:
        A new `glm.Content`. When none of the inputs carries any part, the
        result has an empty `parts` list rather than raising `IndexError`.
    """
    contents = tuple(contents)
    role = next((c.role for c in contents if c.role), "")

    merged_parts = []
    for content in contents:
        for part in content.parts:
            # Only fuse when both the previous part and this one carry text.
            if merged_parts and merged_parts[-1].text and part.text:
                # Build a new Part from the previous one instead of appending to
                # it in place -- presumably so the input chunks stay untouched.
                merged_part = glm.Part(merged_parts[-1])
                merged_part.text += part.text
                merged_parts[-1] = merged_part
            else:
                merged_parts.append(part)

    return glm.Content(
        role=role,
        parts=merged_parts,
    )


def _join_candidates(candidates: Iterable[glm.Candidate]):
    """Fold the streamed pieces of one candidate into a single `glm.Candidate`.

    The index comes from the first piece and the finish reason from the last;
    contents, safety ratings and citation metadata are combined by the
    corresponding `_join_*` helpers. An empty input raises `IndexError`.
    """
    pieces = tuple(candidates)
    first_piece = pieces[0]  # All pieces are expected to share this index.
    last_piece = pieces[-1]

    contents = [piece.content for piece in pieces]
    ratings_lists = [piece.safety_ratings for piece in pieces]
    citations = [piece.citation_metadata for piece in pieces]

    return glm.Candidate(
        index=first_piece.index,
        content=_join_contents(contents),
        finish_reason=last_piece.finish_reason,
        safety_ratings=_join_safety_ratings_lists(ratings_lists),
        citation_metadata=_join_citation_metadatas(citations),
    )


def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]):
    # Assuming that is a candidate ends, it is no longer returned in the list of
    # candidates and that's why candidates have an index
    candidates = collections.defaultdict(list)
    for candidate_list in candidate_lists:
        for candidate in candidate_list:
            candidates[candidate.index].append(candidate)

    new_candidates = []
    for index, candidate_parts in sorted(candidates.items()):
        new_candidates.append(_join_candidates(candidate_parts))

    return new_candidates


def _join_prompt_feedbacks(
    prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback],
):
    # Always return the first prompt feedback.
    return next(iter(prompt_feedbacks))


def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]):
    """Combine streamed response chunks into one `glm.GenerateContentResponse`.

    Args:
        chunks: The chunks to merge, in stream order. Any iterable is accepted,
            including a one-shot iterator.

    Returns:
        A new response whose candidates are the per-index join of all chunk
        candidates and whose prompt feedback is that of the first chunk.
    """
    # The chunks are walked twice below (candidates, then prompt feedback), so
    # materialize them first; a one-shot iterator would otherwise be exhausted
    # by the first pass.
    chunks = tuple(chunks)
    return glm.GenerateContentResponse(
        candidates=_join_candidate_lists(c.candidates for c in chunks),
        prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
    )


# Message of the `IncompleteIterationError` raised when `candidates` is read
# before the response is marked done.
_INCOMPLETE_ITERATION_MESSAGE = """\
Please let the response complete iteration before accessing the final accumulated
attributes (or call `response.resolve()`)"""


class BaseGenerateContentResponse:
    """State and quick accessors shared by the sync and async response classes.

    Holds the accumulated result, the chunks received so far, the remaining
    chunk iterator (if any) and a `done` flag; subclasses drive the iteration.
    """

    def __init__(
        self,
        done: bool,
        iterator: None
        | Iterable[glm.GenerateContentResponse]
        | AsyncIterable[glm.GenerateContentResponse],
        result: glm.GenerateContentResponse,
        chunks: Iterable[glm.GenerateContentResponse],
    ):
        """Store the response state.

        Args:
            done: Whether the response is already complete.
            iterator: Source of the remaining chunks, or `None`.
            result: The response accumulated so far.
            chunks: The chunks received so far; copied into a new list.
        """
        self._done = done
        self._iterator = iterator
        self._result = result
        self._chunks = list(chunks)
        # A blocked prompt is remembered as an error; subclasses raise it on iteration.
        if result.prompt_feedback.block_reason:
            self._error = BlockedPromptException(result)
        else:
            self._error = None

    @property
    def candidates(self):
        """The list of candidate responses.

        Raises:
            IncompleteIterationError: With `stream=True` if iteration over the stream was not completed.
        """
        if not self._done:
            raise IncompleteIterationError(_INCOMPLETE_ITERATION_MESSAGE)
        return self._result.candidates

    @property
    def parts(self):
        """A quick accessor equivalent to `self.candidates[0].content.parts`

        Raises:
            ValueError: If the candidate list does not contain exactly one candidate.
        """
        candidates = self.candidates
        if not candidates:
            raise ValueError(
                "The `response.parts` quick accessor only works for a single candidate, "
                "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked."
            )
        if len(candidates) > 1:
            raise ValueError(
                "The `response.parts` quick accessor only works with a "
                "single candidate. With multiple candidates use "
                "result.candidates[index].text"
            )
        return candidates[0].content.parts

    @property
    def text(self):
        """A quick accessor equivalent to `self.candidates[0].content.parts[0].text`

        Raises:
            ValueError: If the candidate list or parts list does not contain exactly one entry,
                or if the single part is not a text part.
        """
        parts = self.parts
        if not parts:
            # Without this guard `parts[0]` below would raise an `IndexError`,
            # which contradicts the documented `ValueError` contract.
            raise ValueError(
                "The `response.text` quick accessor only works when the response contains "
                "a valid `Part`, but none was returned. Check the candidate's `finish_reason` "
                "and `safety_ratings` to see if the response was blocked."
            )
        if len(parts) > 1 or "text" not in parts[0]:
            raise ValueError(
                "The `response.text` quick accessor only works for "
                "simple (single-`Part`) text responses. This response "
                "contains multiple `Parts`. Use the `result.parts` "
                "accessor or the full "
                "`result.candidates[index].content.parts` lookup "
                "instead"
            )
        return parts[0].text

    @property
    def prompt_feedback(self):
        """The prompt feedback of the accumulated result; readable before iteration completes."""
        return self._result.prompt_feedback


@contextlib.contextmanager
def rewrite_stream_error():
    """Context manager that turns low-level streaming failures into `BadRequest`.

    Raises:
        google.api_core.exceptions.BadRequest: If the wrapped block raises
            `google.protobuf.json_format.ParseError` or `AttributeError`. The
            original exception is attached as the cause.
    """
    try:
        yield
    except (google.protobuf.json_format.ParseError, AttributeError) as e:
        # Chain explicitly so the traceback shows the underlying failure as the
        # direct cause of the rewritten error.
        raise google.api_core.exceptions.BadRequest(
            "Unknown error trying to retrieve streaming response. "
            "Please retry with `stream=False` for more details."
        ) from e


# Docstring attached to the synchronous `GenerateContentResponse` class, so it
# describes the synchronous API (`generate_content` / `send_message`).
GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method.

    These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`.
    This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback`
    and `candidates` attributes. This class adds several quick accessors for common use cases.

    The same object type is returned for both `stream=True/False`.

    ### Streaming

    When you pass `stream=True` to `GenerativeModel.generate_content` or `ChatSession.send_message`,
    iterate over this object to receive chunks of the response:

    ```
    response = model.generate_content(..., stream=True)
    for chunk in response:
      print(chunk.text)
    ```

    `GenerateContentResponse.prompt_feedback` is available immediately but
    `GenerateContentResponse.candidates`, and all the attributes derived from them (`.text`, `.parts`),
    are only available after the iteration is complete.
    """

# Docstring attached to `AsyncGenerateContentResponse` via `string_utils.set_doc`.
ASYNC_GENERATE_CONTENT_RESPONSE_DOC = (
    """This is the async version of `genai.GenerateContentResponse`."""
)


@string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC)
class GenerateContentResponse(BaseGenerateContentResponse):
    @classmethod
    def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]):
        """Build a not-yet-complete response from a stream of chunks.

        The first chunk is pulled from `iterator` right away and serves both as
        the initial accumulated result and as the first stored chunk; the rest
        of the stream is read when the response is iterated.
        """
        iterator = iter(iterator)
        # Parse/attribute failures while fetching the first chunk surface as `BadRequest`.
        with rewrite_stream_error():
            response = next(iterator)

        return cls(
            done=False,
            iterator=iterator,
            result=response,
            chunks=[response],
        )

    @classmethod
    def from_response(cls, response: glm.GenerateContentResponse):
        """Build an already-complete response wrapping a single `response`."""
        return cls(
            done=True,
            iterator=None,
            result=response,
            chunks=[response],
        )

    def __iter__(self):
        """Yield every chunk, each wrapped as a completed `GenerateContentResponse`.

        When the response is already done, the stored chunks are replayed.
        Otherwise the stream is read one chunk ahead of the chunk being
        yielded, each new chunk is merged into the accumulated result, and the
        response is marked done once the underlying iterator is exhausted.

        Raises:
            Exception: The stored error, if any -- a `BlockedPromptException`
                set at construction, or whatever the underlying iterator
                raised (re-raised on the loop pass after it was caught).
        """
        # This is not thread safe.
        if self._done:
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(next(self._iterator))

        for n in itertools.count():
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = next(self._iterator)
                except StopIteration:
                    self._done = True
                except Exception as e:
                    # Remember the failure; chunk `n` is still yielded below and
                    # the error is raised at the top of the next pass.
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item

    def resolve(self):
        """Consume the rest of the stream so the accumulated attributes become readable.

        Does nothing if the response is already done.
        """
        if self._done:
            return

        for _ in self:
            pass


@string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC)
class AsyncGenerateContentResponse(BaseGenerateContentResponse):
    @classmethod
    async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]):
        """Build a not-yet-complete response from an async stream of chunks.

        The first chunk is awaited right away and serves both as the initial
        accumulated result and as the first stored chunk; the rest of the
        stream is read when the response is iterated.
        """
        # NOTE(review): the `aiter`/`anext` builtins exist only on Python 3.10+ --
        # confirm this matches the supported versions.
        iterator = aiter(iterator)  # type: ignore
        # Parse/attribute failures while fetching the first chunk surface as `BadRequest`.
        with rewrite_stream_error():
            response = await anext(iterator)  # type: ignore

        return cls(
            done=False,
            iterator=iterator,
            result=response,
            chunks=[response],
        )

    @classmethod
    def from_response(cls, response: glm.GenerateContentResponse):
        """Build an already-complete response wrapping a single `response`."""
        return cls(
            done=True,
            iterator=None,
            result=response,
            chunks=[response],
        )

    async def __aiter__(self):
        """Asynchronously yield every chunk, each wrapped as a completed response.

        Each yielded item is built with `GenerateContentResponse.from_response`,
        i.e. it is an instance of the synchronous class.

        When the response is already done, the stored chunks are replayed.
        Otherwise the stream is read one chunk ahead of the chunk being
        yielded, each new chunk is merged into the accumulated result, and the
        response is marked done once the underlying iterator is exhausted.

        Raises:
            Exception: The stored error, if any -- a `BlockedPromptException`
                set at construction, or whatever the underlying iterator
                raised (re-raised on the loop pass after it was caught).
        """
        # This is not thread safe.
        if self._done:
            for chunk in self._chunks:
                yield GenerateContentResponse.from_response(chunk)
            return

        # Always have the next chunk available.
        if len(self._chunks) == 0:
            self._chunks.append(await anext(self._iterator))  # type: ignore

        for n in itertools.count():
            if self._error:
                raise self._error

            if n >= len(self._chunks) - 1:
                # Look ahead for a new item, so that you know the stream is done
                # when you yield the last item.
                if self._done:
                    return

                try:
                    item = await anext(self._iterator)  # type: ignore
                except StopAsyncIteration:
                    self._done = True
                except Exception as e:
                    # Remember the failure; chunk `n` is still yielded below and
                    # the error is raised at the top of the next pass.
                    self._error = e
                    self._done = True
                else:
                    self._chunks.append(item)
                    self._result = _join_chunks([self._result, item])

            item = self._chunks[n]

            item = GenerateContentResponse.from_response(item)
            yield item

    async def resolve(self):
        """Consume the rest of the stream so the accumulated attributes become readable.

        Does nothing if the response is already done.
        """
        if self._done:
            return

        async for _ in self:
            pass
